#include <iostream>
#include <cstdio>
#define ll long long

using namespace std;

constexpr ll mod = 998244353;  // prime modulus; main() prints its answer reduced by this value
// Reads the next decimal integer from stdin.
// Any characters before the first digit are skipped. If a '-' appears among
// them, the result is negated.
// Reading stops at the first non-digit after the digits; that character is
// consumed.
// If EOF is reached before any digit is seen, returns 0. The previous version
// looped forever in that case.
long long read()
{
	long long value = 0;
	int sign = 1;
	// Must be int, not char: getchar() signals end of input with EOF, and
	// narrowing to char makes EOF indistinguishable from a real character.
	int ch = getchar();
	while (ch != EOF && (ch < '0' || ch > '9'))
	{
		if (ch == '-') sign = -1;
		ch = getchar();
	}
	while (ch >= '0' && ch <= '9')
	{
		value = value * 10 + (ch - '0');
		ch = getchar();
	}
	return value * sign;
}
// Reads two integers n and m from bpmp.in.
// Writes n*(m-1) + (n-1), reduced modulo mod, to bpmp.out. Algebraically this
// value is n*m - 1.
int main()
{
	freopen("bpmp.in","r",stdin);
	freopen("bpmp.out","w",stdout);
	ll n=read(),m=read();
	// n*(m-1) + (n-1) == n*m - 1 is symmetric in n and m, so this swap only
	// fixes which value is called n; it does not change the result.
	if(n<m)swap(n,m);
	// Reduce each factor modulo mod before multiplying. The raw product
	// n*(m-1) overflows long long (undefined behaviour) once it passes about
	// 9.2e18. Two residues below mod multiply to less than mod^2 (about 1e18),
	// which fits. For non-negative inputs the final % gives the same value as
	// reducing the exact result.
	ll ans=(n%mod)*((m-1)%mod)+(n-1)%mod;
	printf("%lld",ans%mod);
	return 0;
}
